Initialize Cutlass-SYCL support #6
```python
if cu_seqlens_q is None:  # !is_varlen_q
    cu_seqlens_q = torch.arange(
        0, q.size(0) + 1, dtype=torch.int, device=q.device
    ) * q.size(1)
    max_seqlen_q = q.size(1)
    q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
if cu_seqlens_k_new is None and k is not None:  # !is_varlen_k_new
    cu_seqlens_k_new = torch.arange(
        0, k.size(0) + 1, dtype=torch.int, device=k.device
    )
elif k is None:
    cu_seqlens_k_new = torch.zeros_like(
        cu_seqlens_q, dtype=torch.int32, device=q.device
    )
if cache_seqlens is not None:
    max_seqlen_k = cache_seqlens.max().item()
    assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
    max_page_size_per_seq = page_table.size(1)
    num_pages_per_seq = torch.arange(
        0,
        cache_seqlens.size(0) * max_page_size_per_seq,
        max_page_size_per_seq,
        device=cache_seqlens.device,
    ).to(torch.int32)
    cu_seqlens_k = torch.concat(
        (
            torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
            torch.cumsum(cache_seqlens, 0),
        )
    ).to(torch.int32)
```
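For context, a small worked example (with hypothetical shapes) of what this host-side path produces for a non-varlen batch:

```python
import torch

# Hypothetical non-varlen input: batch 2, seqlen 3, 8 heads, head_dim 64.
q = torch.zeros(2, 3, 8, 64)

# cu_seqlens_q becomes the cumulative query offsets [0, 3, 6].
cu_seqlens_q = torch.arange(0, q.size(0) + 1, dtype=torch.int) * q.size(1)
print(cu_seqlens_q)  # tensor([0, 3, 6], dtype=torch.int32)

# cache_seqlens -> cu_seqlens_k via a zero prefix plus a cumulative sum.
cache_seqlens = torch.tensor([5, 2], dtype=torch.int32)
cu_seqlens_k = torch.concat(
    (torch.zeros(1, dtype=torch.int32), torch.cumsum(cache_seqlens, 0))
).to(torch.int32)
print(cu_seqlens_k)  # tensor([0, 5, 7], dtype=torch.int32)
```

Note that `cache_seqlens.max().item()` forces a device-to-host sync, and each `arange`/`cumsum` is an extra kernel launch on every call.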
These ops are causing a perf degradation compared to Triton.
No worries, we are aware of this. This PR still needs a lot of changes.
You don't have to pay too much attention to it right now; it will be fixed later.
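For reference, one possible direction for that later fix (a sketch under assumptions, not what this PR implements): compute the length metadata once per batch and reuse it across layers, so the blocking `.item()` and the `arange`/`cumsum` launches are amortized rather than repeated per call.

```python
import torch

class PrefillMetadata:
    """Hypothetical caller-side cache of sequence-length metadata.

    Built once per batch so the blocking .item() sync and the
    arange/cumsum kernel launches are not repeated on every call.
    """

    def __init__(self, cache_seqlens: torch.Tensor, page_table: torch.Tensor):
        self.max_seqlen_k = int(cache_seqlens.max().item())  # one sync here
        self.cu_seqlens_k = torch.concat(
            (
                torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
                torch.cumsum(cache_seqlens, 0),
            )
        ).to(torch.int32)
        max_pages = page_table.size(1)
        self.num_pages_per_seq = torch.arange(
            0, cache_seqlens.size(0) * max_pages, max_pages,
            device=cache_seqlens.device,
        ).to(torch.int32)
```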
include/sgl_flash_kernel_ops.h (outdated)
```cpp
std::optional<at::Tensor>& q_descale_,         // (b, h_k), not (b, h)
std::optional<at::Tensor>& k_descale_,         // (b, h_k)
std::optional<at::Tensor>& v_descale_,         // (b, h_k)
std::optional<const at::Tensor>& page_table_,  // (b_k, max_num_pages_per_seq)
```
Why are we changing the function signature?
```python
if cu_seqlens_q is None:  # !is_varlen_q
    cu_seqlens_q = torch.arange(
        0, q.size(0) + 1, dtype=torch.int, device=q.device
    ) * q.size(1)
    max_seqlen_q = q.size(1)
    q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
batch_size = cu_seqlens_q.numel() - 1
page_table = (
    torch.arange(0, batch_size, device=q.device)
    .to(torch.int32)
    .reshape([batch_size, 1])
    .contiguous()
)
```
What extra functionality are we trying to provide?
The current kernel implementation aligns vLLM and SGLang requests, so there will be some changes on the SGLang side.
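For what it's worth, the synthesized table treats each non-paged sequence as a single page whose index is its batch position, which lets non-paged requests reuse the paged-KV kernel path (an inference from the diff above, not a confirmed design note):

```python
import torch

batch_size = 3  # illustrative
page_table = (
    torch.arange(0, batch_size)
    .to(torch.int32)
    .reshape([batch_size, 1])
    .contiguous()
)
print(page_table)
# tensor([[0],
#         [1],
#         [2]], dtype=torch.int32)
```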
src/sycl/chunked_prefill.cpp (outdated)
| #include "cutlass/util/device_memory.h" | ||
| #include "cutlass/util/packed_stride.hpp" | ||
| #include "cutlass/util/reference/device/gemm_complex.h" | ||
| #include "cutlass/util/reference/device/tensor_compare.h" |
We don't need the header files for the verification code.
```cpp
if (params.page_table != nullptr && params.cu_seqlens_k != nullptr) {
  return run<true, true, cutlass::flash_attention::IndividualScheduler>(params);
} else {
  return 0;
}
```
Only use paged KV?
src/sycl/chunked_prefill.cpp (outdated)
```cpp
CHECK_DEVICE(v_new);
TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 1;
```
Should seqlen_k_new be 1 here, or 0?
src/sycl/chunked_prefill.cpp (outdated)
```cpp
at::Tensor out_accum, softmax_lse_accum;
auto outaccum_type = at::ScalarType::Float;

constexpr int PipelineStages = 0;
```
Set PipelineStages to 2.
src/sycl/chunked_prefill.cpp (outdated)
```cpp
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU")
#define CHECK_SHAPE(x, ...) \
  TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
```
Move these macros to utils.
@airMeng fix lint
* Initialize Cutlass support
* Add chunked prefill op

Co-authored-by: Swift.Sun <[email protected]>
Add a chunked prefill op. The PR currently works with oneAPI 2025.1.
Llama-3B BF16 accuracy results were verified on BMG-12GB. You need to install SGLang per the instructions at https://github.com/airMeng/sglang/blob/xpu_attention/docs/platforms/xpu.md
To reproduce the accuracy results, launch the server first, then run the accuracy scripts in SGLang (see the sketch below).
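A minimal reproduction sketch; the model id, flags, and benchmark script path are assumptions, and the linked xpu.md is the authoritative reference:

```python
# Hypothetical reproduction driver; flag names and paths are assumptions,
# see the linked xpu.md for the authoritative commands.
import subprocess

# 1) Launch the SGLang server with a Llama-3B model.
server = subprocess.Popen(
    [
        "python3", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-3.2-3B-Instruct",  # assumed model id
    ]
)

# 2) Once the server is up, run an accuracy script from the SGLang repo,
#    e.g. the GSM8K benchmark (path assumed from the SGLang tree).
subprocess.run(["python3", "benchmark/gsm8k/bench_sglang.py"])

server.terminate()
```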
The PR cannot work with the current open-source oneAPI release due to a SYCLCompat issue. You can update your local oneAPI according to intel/llvm#19673.